{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "04203.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04203.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EakXiORfCoCj",
        "colab_type": "text"
      },
      "source": [
        "## 4.5 读取和存储"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uWnE0fynC4jw",
        "colab_type": "text"
      },
      "source": [
        "### 4.5.1 读写 Tensor"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "xc6ZhV2hC9OT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import torch\n",
        "from torch import nn \n",
        "\n",
        "x = torch.ones(3)\n",
        "torch.save(x, 'x.pt')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cls9hSkbDH1U",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "38b68a97-4aec-4e32-9849-400fc08b9787"
      },
      "source": [
        "x2 = torch.load('x.pt')\n",
        "x2"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([1., 1., 1.])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CfjfBkhrDMK1",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "87c5e718-1c3c-4573-b8f0-bd46527ece7a"
      },
      "source": [
        "y = torch.zeros(4)\n",
        "torch.save([x, y], 'xy.pt')\n",
        "xy_list = torch.load('xy.pt')\n",
        "xy_list"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "__fcl9pIDbgr",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "9c1525f1-581c-4c13-fa98-bbe186cf9407"
      },
      "source": [
        "torch.save({'x':x, 'y':y}, 'xy_dict.pt')\n",
        "xy = torch.load('xy_dict.pt')\n",
        "xy"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_e4IO18TDu0N",
        "colab_type": "text"
      },
      "source": [
        "### 4.5.2 读写模型"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ULDcN5alD0NQ",
        "colab_type": "text"
      },
      "source": [
        "#### 1. state_dict"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rh4uJYsND6YI",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "e38bb958-3257-4243-9064-03b4454f70cf"
      },
      "source": [
        "class MLP(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(MLP, self).__init__()\n",
        "        self.hidden = nn.Linear(3, 2)\n",
        "        self.act = nn.ReLU()\n",
        "        self.output = nn.Linear(2, 1)\n",
        "\n",
        "    def forward(self, x):\n",
        "        a = self.act(self.hidden(x))\n",
        "        return self.output(a)\n",
        "\n",
        "net = MLP()\n",
        "net.state_dict()"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "OrderedDict([('hidden.weight', tensor([[-0.1676,  0.1718,  0.1201],\n",
              "                      [-0.2257,  0.2118,  0.1144]])),\n",
              "             ('hidden.bias', tensor([-0.1384,  0.2647])),\n",
              "             ('output.weight', tensor([[ 0.3813, -0.5593]])),\n",
              "             ('output.bias', tensor([-0.1984]))])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jFyeAprVEU0r",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 185
        },
        "outputId": "54c02158-26ae-4587-ec51-4768e18364d6"
      },
      "source": [
        "optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
        "optimizer.state_dict()"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'param_groups': [{'dampening': 0,\n",
              "   'lr': 0.001,\n",
              "   'momentum': 0.9,\n",
              "   'nesterov': False,\n",
              "   'params': [140599130499400,\n",
              "    140599130498464,\n",
              "    140599131621992,\n",
              "    140599130497744],\n",
              "   'weight_decay': 0}],\n",
              " 'state': {}}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 7
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eFnREG3bEgxQ",
        "colab_type": "text"
      },
      "source": [
        "#### 2. 保存和加载模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-9J9L30DE61_",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        },
        "outputId": "5a0b97c0-aaef-43e2-d770-215d17f3afc5"
      },
      "source": [
        "X = torch.randn(2, 3)\n",
        "Y = net(X)\n",
        "\n",
        "PATH = \"./net.pt\"\n",
        "torch.save(net.state_dict(), PATH)\n",
        "\n",
        "net2 = MLP()\n",
        "net2.load_state_dict(torch.load(PATH))\n",
        "Y2 = net2(X)\n",
        "Y2 == Y"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "tensor([[True],\n",
              "        [True]])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ]
    }
  ]
}